{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prov-GigaPath Demo\n",
    "\n",
    "This notebook provides a quick walkthrough of the Prov-GigaPath models. We will start by demonstrating how to download the Prov-GigaPath models from HuggingFace. Next, we will show an example of pre-processing a slide. Finally, we will demonstrate how to run Prov-GigaPath on the sample slide.\n",
    "\n",
    "### Prepare HF Token\n",
    "\n",
    "To begin, please request access to the model from our HuggingFace repository: https://huggingface.co/prov-gigapath/prov-gigapath.\n",
    "\n",
    "Once approved, set the HF_TOKEN to access the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Please set your Hugging Face API token\n",
    "os.environ[\"HF_TOKEN\"] = \"YOUR_HF_TOKEN\"\n",
    "\n",
    "assert \"HF_TOKEN\" in os.environ, \"Please set the HF_TOKEN environment variable to your Hugging Face API token\""
   ]
  },
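  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an alternative to pasting the token into the cell above (where it can end up in version control), you can run the minimal sketch below instead. It reuses an exported `HF_TOKEN` and otherwise prompts for the token with Python's standard `getpass` module, without echoing it. The helper `ensure_hf_token` is hypothetical and not part of the `gigapath` package; define it, then call `ensure_hf_token()` before downloading anything."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from getpass import getpass\n",
    "\n",
    "\n",
    "def ensure_hf_token(placeholder='YOUR_HF_TOKEN'):\n",
    "    # Hypothetical helper: reuse an existing HF_TOKEN, otherwise prompt without echoing the input\n",
    "    token = os.environ.get('HF_TOKEN', '')\n",
    "    if token in ('', placeholder):\n",
    "        token = getpass('Hugging Face access token: ')\n",
    "    os.environ['HF_TOKEN'] = token\n",
    "    return token"
   ]
  },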
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Download a sample slide"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/naotous/miniconda3/envs/gigapath/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "PROV-000-000001.ndpi: 100%|██████████| 424M/424M [00:02<00:00, 199MB/s]  \n"
     ]
    }
   ],
   "source": [
    "import huggingface_hub\n",
    "\n",
    "local_dir = os.path.join(os.path.expanduser(\"~\"), \".cache/\")\n",
    "huggingface_hub.hf_hub_download(\"prov-gigapath/prov-gigapath\", filename=\"sample_data/PROV-000-000001.ndpi\", local_dir=local_dir, force_download=True)\n",
    "slide_path = os.path.join(local_dir, \"sample_data/PROV-000-000001.ndpi\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Tiling\n",
    "\n",
    "Whole-slide images are giga-pixel in size. To efficiently process these enormous images, we use a tiling technique that divides them into smaller, more manageable tile images. As an example, we demonstrate how to process a single slide.\n",
    "\n",
    "NOTE: Prov-GigaPath is trained with slides preprocessed at 0.5 MPP. Ensure that you use the appropriate level for the 0.5 MPP."
   ]
  },
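  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The matching level can be computed from the slide's metadata. As a minimal sketch, assuming you already know the base (level 0) MPP and the downsample factor of each pyramid level (for example from OpenSlide's `slide.properties[openslide.PROPERTY_NAME_MPP_X]` and `slide.level_downsamples`; OpenSlide is an assumption here, not necessarily what `tile_one_slide` uses internally), the hypothetical helper `pick_level_for_mpp` below picks the level whose resolution is closest to 0.5 MPP. It does not resample, so if no level is close to 0.5 MPP you would still need to resize the tiles."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pick_level_for_mpp(base_mpp, level_downsamples, target_mpp=0.5):\n",
    "    # Resolution of level i in microns per pixel: base_mpp * downsample_i\n",
    "    level_mpps = [base_mpp * d for d in level_downsamples]\n",
    "    # Index of the level whose MPP is closest to the target\n",
    "    return min(range(len(level_mpps)), key=lambda i: abs(level_mpps[i] - target_mpp))\n",
    "\n",
    "\n",
    "# Hypothetical example: a 0.25 MPP scan whose pyramid halves the resolution at each level\n",
    "print(pick_level_for_mpp(0.25, [1.0, 2.0, 4.0, 8.0]))  # prints 1"
   ]
  },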
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Processing slide /home/naotous/.cache/sample_data/PROV-000-000001.ndpi at level 1 with tile size 256. Saving to outputs/preprocessing.\n",
      "('slide_id', 'tile_id', 'image', 'label', 'tile_x', 'tile_y', 'occupancy')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Tiles (PROV-0…): 100%|██████████| 1068/1068 [00:17<00:00, 59.57img/s]\n"
     ]
    }
   ],
   "source": [
    "from gigapath.pipeline import tile_one_slide\n",
    "\n",
    "tmp_dir = 'outputs/preprocessing/'\n",
    "tile_one_slide(slide_path, save_dir=tmp_dir, level=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load the tile images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found 1068 image tiles\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "# load image tiles\n",
    "slide_dir = \"outputs/preprocessing/output/\" + os.path.basename(slide_path) + \"/\"\n",
    "image_paths = [os.path.join(slide_dir, img) for img in os.listdir(slide_dir) if img.endswith('.png')]\n",
    "\n",
    "print(f\"Found {len(image_paths)} image tiles\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load the Prov-GigaPath model (tile and slide encoder models)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tile encoder param # 1134953984\n",
      "dilated_ratio:  [1, 2, 4, 8, 16]\n",
      "segment_length:  [1024, 5792, 32768, 185363, 1048576]\n",
      "Number of trainable LongNet parameters:  85148160\n",
      "Global Pooling: True\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "slide_encoder.pth: 100%|██████████| 345M/345M [00:01<00:00, 235MB/s] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[92m Successfully Loaded Pretrained GigaPath model from hf_hub:prov-gigapath/prov-gigapath \u001b[00m\n",
      "Slide encoder param # 86330880\n"
     ]
    }
   ],
   "source": [
    "from gigapath.pipeline import load_tile_slide_encoder\n",
    "\n",
    "# Load the tile and slide encoder models\n",
    "# NOTE: The CLS token is not trained during the slide-level pretraining.\n",
    "# Here, we enable the use of global pooling for the output embeddings.\n",
    "tile_encoder, slide_encoder_model = load_tile_slide_encoder(global_pool=True)"
   ]
  },
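  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `dilated_ratio` and `segment_length` values printed above are the LongNet dilated-attention settings of the slide encoder. Read off the output, the dilation ratio doubles at each of the five levels, while the segment length grows geometrically from 2^10 = 1,024 to 2^20 = 1,048,576 tokens (a factor of 2^2.5, about 5.66, per level, truncated to an integer). The sketch below reproduces the printed lists under that interpretation and computes how many tokens each segment keeps after dilation (`segment_length // dilated_ratio`). It illustrates the printed configuration and is not the library's code; `longnet_schedule` is a hypothetical name."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def longnet_schedule(num_levels=5, min_exp=10, max_exp=20):\n",
    "    # Segment lengths spaced geometrically between 2**min_exp and 2**max_exp, truncated to int\n",
    "    step = (max_exp - min_exp) / (num_levels - 1)\n",
    "    segment_length = [int(2 ** (min_exp + step * i)) for i in range(num_levels)]\n",
    "    # Dilation ratio doubles at every level: 1, 2, 4, ...\n",
    "    dilated_ratio = [2 ** i for i in range(num_levels)]\n",
    "    return dilated_ratio, segment_length\n",
    "\n",
    "\n",
    "dilated_ratio, segment_length = longnet_schedule()\n",
    "print('dilated_ratio: ', dilated_ratio)\n",
    "print('segment_length: ', segment_length)\n",
    "# Tokens kept per segment when only every r-th token is attended to\n",
    "print('tokens per dilated segment:', [w // r for w, r in zip(segment_length, dilated_ratio)])"
   ]
  },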
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run tile-level inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Running inference with tile encoder: 100%|██████████| 9/9 [00:09<00:00,  1.10s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tile_encoder_outputs[tile_embeds].shape: torch.Size([1068, 1536])\n",
      "tile_encoder_outputs[coords].shape: torch.Size([1068, 2])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from gigapath.pipeline import run_inference_with_tile_encoder\n",
    "\n",
    "tile_encoder_outputs = run_inference_with_tile_encoder(image_paths, tile_encoder)\n",
    "\n",
    "for k in tile_encoder_outputs.keys():\n",
    "    print(f\"tile_encoder_outputs[{k}].shape: {tile_encoder_outputs[k].shape}\")"
   ]
  },
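  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The progress bar above shows 9 iterations for 1,068 tiles, so the tile encoder runs on batches of tiles rather than on one tile at a time. Assuming a batch size of 128 (consistent with ceil(1068 / 128) = 9, although the actual default of `run_inference_with_tile_encoder` is not shown here), the sketch below reproduces that chunking with a hypothetical `batched` helper, which is handy for estimating how many forward passes a larger slide will take."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "\n",
    "def batched(items, batch_size):\n",
    "    # Yield consecutive slices of at most batch_size items; the last one may be shorter\n",
    "    for start in range(0, len(items), batch_size):\n",
    "        yield items[start:start + batch_size]\n",
    "\n",
    "\n",
    "num_tiles = 1068  # tile count reported above\n",
    "batch_size = 128  # assumed batch size, for illustration only\n",
    "sizes = [len(batch) for batch in batched(list(range(num_tiles)), batch_size)]\n",
    "print(len(sizes), math.ceil(num_tiles / batch_size), sizes[-1])  # prints 9 9 44"
   ]
  },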
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run slide-level inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_keys(['layer_0_embed', 'layer_1_embed', 'layer_2_embed', 'layer_3_embed', 'layer_4_embed', 'layer_5_embed', 'layer_6_embed', 'layer_7_embed', 'layer_8_embed', 'layer_9_embed', 'layer_10_embed', 'layer_11_embed', 'layer_12_embed', 'last_layer_embed'])\n"
     ]
    }
   ],
   "source": [
    "from gigapath.pipeline import run_inference_with_slide_encoder\n",
    "# run inference with the slide encoder\n",
    "slide_embeds = run_inference_with_slide_encoder(slide_encoder_model=slide_encoder_model, **tile_encoder_outputs)\n",
    "print(slide_embeds.keys())"
   ]
  }
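  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`slide_embeds` holds per-layer embeddings (`layer_0_embed` through `layer_12_embed`) plus `last_layer_embed`, which, judging by its name, holds the final-layer output. If you collect the per-layer embeddings programmatically, sort them by layer number rather than alphabetically, because a plain string sort places `layer_10_embed` before `layer_2_embed`. The hypothetical helper `layer_keys_in_order` below is a minimal sketch of that; in this notebook you would call it as `layer_keys_in_order(slide_embeds.keys())`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "LAYER_KEY = re.compile(r'^layer_(\\d+)_embed$')\n",
    "\n",
    "\n",
    "def layer_keys_in_order(keys):\n",
    "    # Keep only 'layer_<n>_embed' keys and sort them by the integer n rather than as strings\n",
    "    numbered = []\n",
    "    for key in keys:\n",
    "        match = LAYER_KEY.match(key)\n",
    "        if match:\n",
    "            numbered.append((int(match.group(1)), key))\n",
    "    return [key for _, key in sorted(numbered)]\n",
    "\n",
    "\n",
    "# A plain string sort puts 'layer_10_embed' before 'layer_2_embed'\n",
    "keys = ['last_layer_embed', 'layer_10_embed', 'layer_2_embed'] + [f'layer_{i}_embed' for i in (0, 1, 3, 4, 5, 6, 7, 8, 9, 11, 12)]\n",
    "print(sorted(keys)[:4])\n",
    "print(layer_keys_in_order(keys))"
   ]
  }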
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gigapath",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
